{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CycleGAN train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from models.cycleGAN import CycleGAN\n",
    "from utils.loaders import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# run params\n",
    "SECTION = 'paint'\n",
    "RUN_ID = '0001'\n",
    "DATA_NAME = 'apple2orange'\n",
    "RUN_FOLDER = 'run/{}/'.format(SECTION)\n",
    "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n",
    "\n",
    "# Create the run folder and its output sub-folders.\n",
    "# os.makedirs also creates the 'run/<SECTION>' parents when they are missing,\n",
    "# and exist_ok=True makes the cell safe to re-run and fills in any sub-folder\n",
    "# that is absent from an already existing run folder.\n",
    "for subfolder in ('viz', 'images', 'weights'):\n",
    "    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)\n",
    "\n",
    "# 'build' saves the freshly constructed model into RUN_FOLDER; any other value\n",
    "# (e.g. 'load') restores weights from RUN_FOLDER/weights/weights.h5 instead.\n",
    "mode = 'build'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edge length of the square images; reused for the loader's img_res and the model's input_dim.\n",
    "IMAGE_SIZE = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Build the loader for DATA_NAME, passing a square (IMAGE_SIZE, IMAGE_SIZE) resolution.\n",
    "image_resolution = (IMAGE_SIZE, IMAGE_SIZE)\n",
    "data_loader = DataLoader(dataset_name=DATA_NAME, img_res=image_resolution)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Construct the CycleGAN for 3-channel IMAGE_SIZE x IMAGE_SIZE inputs.\n",
    "gan = CycleGAN(\n",
    "    input_dim=(IMAGE_SIZE, IMAGE_SIZE, 3),\n",
    "    learning_rate=0.0002,\n",
    "    buffer_max_length=50,\n",
    "    lambda_validation=1,\n",
    "    lambda_reconstr=10,\n",
    "    lambda_id=2,\n",
    "    generator_type='unet',\n",
    "    gen_n_filters=32,\n",
    "    disc_n_filters=32,\n",
    ")\n",
    "\n",
    "# Anything other than 'build' reloads previously stored weights;\n",
    "# 'build' saves the new model into the run folder.\n",
    "weights_path = os.path.join(RUN_FOLDER, 'weights/weights.h5')\n",
    "if mode != 'build':\n",
    "    gan.load_weights(weights_path)\n",
    "else:\n",
    "    gan.save(RUN_FOLDER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call summary() on gan.g_BA (by its name, presumably the B-to-A generator; confirm in models/cycleGAN).\n",
    "gan.g_BA.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call summary() on gan.g_AB (by its name, presumably the A-to-B generator; confirm in models/cycleGAN).\n",
    "gan.g_AB.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call summary() on gan.d_A (by its name, presumably the domain-A discriminator; confirm in models/cycleGAN).\n",
    "gan.d_A.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call summary() on gan.d_B (by its name, presumably the domain-B discriminator; confirm in models/cycleGAN).\n",
    "gan.d_B.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training settings; each one is forwarded to gan.train in the cell that starts training.\n",
    "BATCH_SIZE = 1\n",
    "EPOCHS = 200\n",
    "# Passed to gan.train as sample_interval.\n",
    "PRINT_EVERY_N_BATCHES = 10\n",
    "\n",
    "# Passed as test_A_file / test_B_file (presumably image file names inside the dataset; verify against DataLoader).\n",
    "TEST_A_FILE = 'n07740461_14740.jpg'\n",
    "TEST_B_FILE = 'n07749192_4241.jpg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Start training; every setting below is forwarded to gan.train unchanged.\n",
    "gan.train(\n",
    "    data_loader,\n",
    "    run_folder=RUN_FOLDER,\n",
    "    epochs=EPOCHS,\n",
    "    test_A_file=TEST_A_FILE,\n",
    "    test_B_file=TEST_B_FILE,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    sample_interval=PRINT_EVERY_N_BATCHES,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot selected entries of each gan.g_losses record against the batch index.\n",
    "# The index notes (1 = discriminator, 3 = cycle, 5 = identity) are the notebook's\n",
    "# own annotations; index 0 is unannotated (presumably the combined loss, TODO confirm).\n",
    "# Indices 2, 4, 6 and gan.d_losses are not plotted.\n",
    "series_styles = [\n",
    "    (1, 'green', 0.1),   # discriminator loss\n",
    "    (3, 'blue', 0.1),    # cycle loss\n",
    "    (5, 'red', 0.25),    # identity loss\n",
    "    (0, 'black', 0.25),\n",
    "]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "for column, colour, width in series_styles:\n",
    "    ax.plot([record[column] for record in gan.g_losses], color=colour, linewidth=width)\n",
    "\n",
    "ax.set_xlabel('batch', fontsize=18)\n",
    "ax.set_ylabel('loss', fontsize=16)\n",
    "ax.set_ylim(0, 5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gdl",
   "language": "python",
   "name": "gdl"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
